{
 "cells": [
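  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Neural Network On PPU\n",
    "\n",
    "> If you have not yet read *Logistic Regression On PPU*, please read that tutorial first.\n",
    "\n",
    "In *Logistic Regression On PPU*, we showed how to use Secretflow/PPU to convert a plaintext machine learning training task into a privacy-preserving one.\n",
    "\n",
    "In this tutorial, we apply the same kind of conversion to a neural network model.\n",
    "\n",
    "We keep the same dataset, [Breast Cancer](https://archive.ics.uci.edu/ml/datasets/breast+cancer+wisconsin+(diagnostic)), and the same preprocessing functions.\n",
    "\n",
    "As the first step, we again build a plaintext NN model."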
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plaintext NN Model\n",
    "\n",
    "### Generating the Data\n",
    "\n",
    "The functions used here are the same as in *Logistic Regression On PPU*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional, Tuple\n",
    "\n",
    "import numpy as np\n",
    "from sklearn.datasets import load_breast_cancer\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "\n",
    "def load_dataset(return_label=False) -> Tuple[np.ndarray, Optional[np.ndarray]]:\n",
    "    features, label = load_breast_cancer(return_X_y=True)\n",
    "\n",
    "    # The 30 features are split vertically: the caller asking for the label\n",
    "    # gets columns 15-29 plus the label, the other gets columns 0-14 only.\n",
    "    if return_label:\n",
    "        return features[:, 15:], label\n",
    "    else:\n",
    "        return features[:, :15], None\n",
    "\n",
    "\n",
    "def transform(data):\n",
    "    # Standardize each feature column to zero mean and unit variance.\n",
    "    scaler = StandardScaler()\n",
    "    return scaler.fit_transform(data)"
   ]
  },
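  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a brief sketch of what `transform` computes: `StandardScaler` standardizes each feature column independently. The symbols below ($x_{ij}$ for sample $i$ and feature $j$, $\\mu_j$, $\\sigma_j$, and the sample count $N$) are notation introduced here for illustration and do not appear in the original code.\n",
    "\n",
    "$$z_{ij} = \\frac{x_{ij} - \\mu_j}{\\sigma_j}, \\qquad \\mu_j = \\frac{1}{N}\\sum_{i=1}^{N} x_{ij}, \\qquad \\sigma_j = \\sqrt{\\frac{1}{N}\\sum_{i=1}^{N} (x_{ij} - \\mu_j)^2}$$\n",
    "\n",
    "Because each column is scaled on its own, standardizing the two feature slices separately should give the same values as standardizing the full 30-column matrix."
   ]
  },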
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Training\n",
    "\n",
    "#### Model Definition\n",
    "\n",
    "We use an [MLP](https://en.wikipedia.org/wiki/Multilayer_perceptron) model here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Sequence\n",
    "import flax.linen as nn\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    # Output width of each Dense layer, in order.\n",
    "    features: Sequence[int]\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, x):\n",
    "        # Hidden layers: Dense followed by ReLU.\n",
    "        for feat in self.features[:-1]:\n",
    "            x = nn.relu(nn.Dense(feat)(x))\n",
    "        # Output layer: Dense with no activation.\n",
    "        return nn.Dense(self.features[-1])(x)"
   ]
  },
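  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the architecture concrete, here is a sketch of the forward pass for the layer sizes `[30, 8, 1]` that the training function in this tutorial uses. The names $W_k$ and $b_k$ are notation assumed here for the kernel and bias of the $k$-th `nn.Dense` layer, and $x$ stands for one row of the 30-feature input.\n",
    "\n",
    "$$h_1 = \\mathrm{ReLU}(x W_1 + b_1), \\qquad h_2 = \\mathrm{ReLU}(h_1 W_2 + b_2), \\qquad \\hat{y} = h_2 W_3 + b_3$$\n",
    "\n",
    "Under these assumptions, $W_1 \\in \\mathbb{R}^{30 \\times 30}$, $W_2 \\in \\mathbb{R}^{30 \\times 8}$ and $W_3 \\in \\mathbb{R}^{8 \\times 1}$. The output layer has no activation."
   ]
  },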
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Training Function\n",
    "\n",
    "The training function is shown below. We will train a three-layer MLP model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def fit_auto_grad(x1, x2, y, n_epochs=10, n_batch=10, step_size=0.01):\n",
    "    # Join the two feature slices into one (n_samples, 30) matrix.\n",
    "    x = jnp.concatenate((x1, x2), axis=1)\n",
    "    # Split the samples into n_iters mini-batches.\n",
    "    n_iters = 10\n",
    "    xs = jnp.array_split(x, n_iters, axis=0)\n",
    "    ys = jnp.array_split(y, n_iters, axis=0)\n",
    "    features = [30, 8, 1]\n",
    "    model = MLP(features)\n",
    "\n",
    "    # n_batch only sets the shape of the dummy input used to initialize params.\n",
    "    params = model.init(jax.random.PRNGKey(1), jnp.ones((n_batch, features[0])))\n",
    "\n",
    "    def loss_func(params, x, y):\n",
    "        # Flatten predictions from (n, 1) to (n,) so that y - pred is\n",
    "        # element-wise instead of broadcasting to an (n, n) matrix.\n",
    "        pred = model.apply(params, x).reshape(-1)\n",
    "        # Half the mean squared error.\n",
    "        return jnp.mean(jnp.multiply(y - pred, y - pred) / 2.0)\n",
    "\n",
    "    def body_fun(_, params):\n",
    "        # One epoch: a plain gradient descent step on every mini-batch.\n",
    "        for x_batch, y_batch in zip(xs, ys):\n",
    "            grads = jax.grad(loss_func)(params, x_batch, y_batch)\n",
    "            # Apply the update to every leaf of the parameter tree.\n",
    "            params = jax.tree_util.tree_map(\n",
    "                lambda p, g: p - step_size * g, params, grads\n",
    "            )\n",
    "        return params\n",
    "\n",
    "    params = jax.lax.fori_loop(0, n_epochs, body_fun, params)\n",
    "    return params\n"
   ]
  },
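  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of what the training function is meant to do, each mini-batch step is one step of plain gradient descent on half the mean squared error. The symbols $\\theta$ (all model parameters), $\\eta$ (`step_size`), $B$ (the number of samples in the mini-batch) and $\\hat{y}_i$ (the model output for sample $i$) are notation assumed here for illustration.\n",
    "\n",
    "$$L(\\theta) = \\frac{1}{2B}\\sum_{i=1}^{B} (y_i - \\hat{y}_i)^2, \\qquad \\theta \\leftarrow \\theta - \\eta\\, \\nabla_\\theta L(\\theta)$$\n",
    "\n",
    "With the defaults (`n_epochs=10` and 10 mini-batches per epoch), this amounts to 100 parameter updates."
   ]
  },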
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Validating the Model\n",
    "\n",
    "We compute the model's loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_loss(params, x1, x2, y):\n",
    "    # Join the two feature slices into one (n_samples, 30) matrix.\n",
    "    x = jnp.concatenate((x1, x2), axis=1)\n",
    "    mlp = MLP([30, 8, 1])\n",
    "    # Flatten predictions from (n, 1) to (n,) so that y - pred is\n",
    "    # element-wise instead of broadcasting to an (n, n) matrix.\n",
    "    pred = mlp.apply(params, x).reshape(-1)\n",
    "    # Half the mean squared error.\n",
    "    return jnp.mean(jnp.multiply(y - pred, y - pred) / 2.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Putting It All Together\n",
    "\n",
    "The complete model training process is as follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.14214484\n"
     ]
    }
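   ],
   "source": [
    "# First feature slice (columns 0-14), without the label.\n",
    "x1, _ = load_dataset(False)\n",
    "# Second feature slice (columns 15-29), with the label.\n",
    "x2, y = load_dataset(True)\n",
    "\n",
    "# Standardize each slice.\n",
    "x1 = transform(x1)\n",
    "x2 = transform(x2)\n",
    "\n",
    "# Train in plaintext, then report the loss on the training data.\n",
    "params = fit_auto_grad(x1, x2, y)\n",
    "\n",
    "print(compute_loss(params, x1, x2, y))\n"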
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As the second step, we convert the training task above into a privacy-preserving one.\n",
    "\n",
    "## Privacy-Preserving NN Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(PPURuntime pid=111570)\u001b[0m I0303 17:03:40.176662 111570 external/com_github_brpc_brpc/src/brpc/server.cpp:1046] Server[ppu::link::internal::ReceiverServiceImpl] is serving on port=28788.\n",
      "\u001b[2m\u001b[36m(PPURuntime pid=111570)\u001b[0m I0303 17:03:40.176756 111570 external/com_github_brpc_brpc/src/brpc/server.cpp:1049] Check out http://k69b13338.eu95sqa:28788 in web browser.\n",
      "\u001b[2m\u001b[36m(PPURuntime pid=111568)\u001b[0m I0303 17:03:40.200518 111568 external/com_github_brpc_brpc/src/brpc/server.cpp:1046] Server[ppu::link::internal::ReceiverServiceImpl] is serving on port=53280.\n",
      "\u001b[2m\u001b[36m(PPURuntime pid=111568)\u001b[0m I0303 17:03:40.200621 111568 external/com_github_brpc_brpc/src/brpc/server.cpp:1049] Check out http://k69b13338.eu95sqa:53280 in web browser.\n",
      "\u001b[2m\u001b[36m(PPURuntime pid=111570)\u001b[0m I0303 17:03:40.277556 111980 external/com_github_brpc_brpc/src/brpc/socket.cpp:2202] Checking Socket{id=0 addr=127.0.0.1:53280} (0x7f8e8d9bdb80)\n",
      "\u001b[2m\u001b[36m(PPURuntime pid=111570)\u001b[0m I0303 17:03:40.277735 112010 external/com_github_brpc_brpc/src/brpc/socket.cpp:2262] Revived Socket{id=0 addr=127.0.0.1:53280} (0x7f8e8d9bdb80) (Connectable)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(PPURuntime pid=111570)\u001b[0m [2022-03-03 17:03:40.176] [info] [context.cc:58] connecting to mesh, id=root, self=0\n",
      "\u001b[2m\u001b[36m(PPURuntime pid=111570)\u001b[0m [2022-03-03 17:03:40.193] [info] [context.cc:83] try_connect to rank 1 not succeed, sleep_for 1000ms and retry.\n",
      "\u001b[2m\u001b[36m(PPURuntime pid=111568)\u001b[0m [2022-03-03 17:03:40.200] [info] [context.cc:58] connecting to mesh, id=root, self=1\n",
      "\u001b[2m\u001b[36m(PPURuntime pid=111570)\u001b[0m [2022-03-03 17:03:41.194] [info] [context.cc:111] connected to mesh, id=root, self=0\n",
      "\u001b[2m\u001b[36m(PPURuntime pid=111568)\u001b[0m [2022-03-03 17:03:41.194] [info] [context.cc:111] connected to mesh, id=root, self=1\n"
     ]
    }
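   ],
   "source": [
    "import secretflow as sf\n",
    "\n",
    "# Initialize secretflow with two parties, alice and bob.\n",
    "sf.init(['alice', 'bob'], num_cpus=8, log_to_driver=True)\n",
    "\n",
    "# One PYU device per party for its plaintext data, and one PPU device\n",
    "# shared by both parties for the MPC computation.\n",
    "alice, bob = sf.PYU('alice'), sf.PYU('bob')\n",
    "ppu = sf.PPU(sf.utils.testing.cluster_def(['alice', 'bob']))\n",
    "\n",
    "# alice holds the first feature slice; bob holds the second slice and the label.\n",
    "x1_private, x2_private, y_private = sf.to(alice, x1), sf.to(bob, x2), sf.to(bob, y)\n",
    "\n",
    "# Move each party's private input to the PPU.\n",
    "x1_ppu = x1_private.to(ppu)\n",
    "x2_ppu = x2_private.to(ppu)\n",
    "y_ppu = y_private.to(ppu)\n",
    "\n",
    "# Run the same training function as before, this time on the PPU.\n",
    "params_ppu = ppu(fit_auto_grad)(x1_ppu, x2_ppu, y_ppu)"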
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's check the params produced by the MPC training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.14270556\n"
     ]
    }
   ],
   "source": [
    "# Reveal the trained parameters and evaluate them on the plaintext data.\n",
    "loss = compute_loss(sf.reveal(params_ppu), x1, x2, y)\n",
    "\n",
    "print(loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That is all for this tutorial. As you can see, only a few changes are needed to turn an ordinary NN training task into a privacy-preserving one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ray\n",
    "\n",
    "ray.shutdown()"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "65688097505482c031030402772b0ac100ff43456359c24dc5c31ad512b134c5"
  },
  "kernelspec": {
   "display_name": "Python 3.8.11 ('python3.8')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
